package command

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"os"
	"redis-check/conf"
	"redis-check/tool"
	"time"

	"redis-check/client"
	"redis-check/common"
	"redis-check/handle"
	"redis-check/metric"

	_ "github.com/logoove/sqlite"
)

// CompareOptions holds the command-line options of the compare command.
//
// The struct tags (group, namespace, short, long, default, description) are
// read at runtime — presumably by the flag parser behind
// common.RunWithArgs — so they are part of the program's behavior.
// English summary of the flags, translated from the tag descriptions:
//
//   - Source / Target: connection settings of the source and the target
//     Redis (flag namespaces "s" and "t").
//   - Match (--match): key pattern used when scanning the source Redis; all
//     keys are scanned when it is not given.
//   - DBFile (-d, --db, default "result.db"): Sqlite3 database file name; an
//     existing file is overwritten.
//   - ResultFile (--result): path of the result file, whose header format is
//     'db\tdiff-type\tkey\tfield'.
//   - CompareTimes (--times, default 1): number of comparison rounds; each
//     round is based on the result of the previous one. Documented range
//     [1, 5].
//   - Option (-o, --option): processing option. The help text lists a common
//     option (k: given keys), preview options (n: not in target Redis, e: in
//     target Redis, b: big key, m: small key) and compare options (l: compare
//     length, a: compare value, v: compare values of small keys and lengths
//     of big keys).
//     NOTE(review): CreateHandler in this file tests Option against "len",
//     "full" and "auto" rather than these letters — confirm which is current.
//   - Id / JobId / TaskId (--id, --jobid, --taskid, default "unknown"):
//     identifiers attached to the metric.
//   - Qps (-q, --qps, default 1000): maximum processing QPS; with qps=10 the
//     processing rate is 10 * BatchCount per second.
//   - Interval (--interval, default 5): pause between comparison rounds in
//     seconds. Documented range [0, 1000].
//   - BatchCount (--batch, default 256): batch size. Documented range
//     [1, 10000].
//   - Parallel (--parallel, default 1): number of concurrent comparisons.
//     Documented range [1, 100].
//   - MetricPrint (--metric): whether to print the statistics log.
//   - BigKeyThreshold (--bigkey, default 5120): big-key threshold.
//   - KeyList (-k, --keylist): explicit list of keys, e.g. 'abc|efg|dkg'.
//   - FilterList (-f, --filterlist): key filter list, e.g. 'abc*|efg|m*'
//     matches 'abc', 'abc1', 'efg', 'm', 'mxyz' and filters out 'efgh', 'p'.
//
// The documented ranges come from the help text only; no range check is
// visible in this file.
type CompareOptions struct {
	common.BaseOptions
	Source          conf.RedisArgs `group:"source" description:"源Redis配置信息" namespace:"s"`
	Target          conf.RedisArgs `group:"target" description:"目的Redis配置信息" namespace:"t"`
	Match           string         `long:"match" default:"" description:"源Redis扫描Key的匹配模式，未指定时扫描所有Key"`
	DBFile          string         `short:"d" long:"db" default:"result.db" description:"Sqlite3 DB文件名，如果已经存在，则覆盖此文件"`
	ResultFile      string         `long:"result"  description:"结果文件地址，文件头格式：'db\tdiff-type\tkey\tfield'"`
	CompareTimes    int            `long:"times" default:"1" description:"比较次数，下次比较基于上次的结果，值范围[1, 5]"`
	Option          string         `short:"o" long:"option" default:"" description:"处理选项，公共选项（k-指定Key），preview选项（n-不在目标Redis中，e-在目标Redis中，b-大Key, m-小Key） compare选项：（l-比较长度，a-比较值，v-比较小Key的值和大Key的长度）"`
	Id              string         `long:"id" default:"unknown" description:"Metric 的Id"`
	JobId           string         `long:"jobid" default:"unknown" description:"Metric 的JobId"`
	TaskId          string         `long:"taskid" default:"unknown" description:"Metric 的TaskId"`
	Qps             int            `short:"q" long:"qps" default:"1000" description:"最大处理QPS，，如果qps=10，则获取处理速度为 10 * BatchCount每秒"`
	Interval        int            `long:"interval" default:"5" description:"比较处理间隔，CompareTimes > 0时有效，单位：秒，值范围[0, 1000]"`
	BatchCount      int            `long:"batch" default:"256" description:"批量处理大小，值范围[1, 10000]"`
	Parallel        int            `long:"parallel" default:"1" description:"并发比较数量，值范围[1, 100]"`
	MetricPrint     bool           `long:"metric" description:"是否打印统计日志"`
	BigKeyThreshold int64          `long:"bigkey" default:"5120" description:"大Key阀值"`
	KeyList         string         `short:"k" long:"keylist" default:"" description:"指定Key列表，例如: 'abc|efg|dkg'"`
	FilterList      string         `short:"f" long:"filterlist" default:"" description:"Key过滤列表，例如: 'abc*|efg|m*' 匹配 'abc'，'abc1'，'efg'，'m'，'mxyz'，过滤 'efgh'，'p'"`
}

// CompareCommand implements the compare sub-command. It carries the parsed
// options together with the state shared between the comparison rounds and
// the periodic statistics output.
type CompareCommand struct {
	Options            *CompareOptions     // options handed to common.RunWithArgs by Execute
	SourceHost         client.RedisHost    // set by Execute from Options.Source; scanned by HandleKeyCompare
	TargetHost         client.RedisHost    // set by Execute from Options.Target; connected to in HandleKeyCompare
	Stat               metric.Stat         // counters reset/rotated during a round and read by PrintStat
	totalConflict      int64               // totalKeyConflict + totalFieldConflict, recomputed by PrintStat
	totalKeyConflict   int64               // key conflicts summed by PrintStat, only during the last round
	totalFieldConflict int64               // field conflicts summed by PrintStat, only during the last round
	times              int                 // 1-based number of the round currently running
	tickerStat         *time.Ticker        // ticker that drives the periodic PrintStat calls
	targetClient       *client.RedisClient // target connection opened in HandleKeyCompare for state.CurrDb
}

// Execute runs the compare command.
//
// It parses the options with common.RunWithArgs, resolves the source and
// target Redis hosts, creates the comparator once and then performs
// Options.CompareTimes comparison rounds, sleeping Options.Interval seconds
// before every round except the first.
//
// The args parameter is not used; options are parsed from os.Args[2:].
// When the first result of common.RunWithArgs is false, Execute returns the
// accompanying error (which may be nil) without doing any work; otherwise it
// returns nil.
func (p *CompareCommand) Execute(args []string) error {
	ret, err := common.RunWithArgs(p.Options, os.Args[2:])
	if !ret {
		return err
	}
	p.SourceHost = p.Options.Source.ToRedisHost("source", true)
	p.TargetHost = p.Options.Target.ToRedisHost("target", true)
	handler := p.CreateHandler()
	for p.times = 1; p.times <= p.Options.CompareTimes; p.times++ {
		if p.times != 1 {
			// The original format string had two characters transposed
			// ("稍%d等秒后"); the message means "waiting N seconds before
			// starting round M".
			common.Logger.Infof("稍等%d秒后启动第%d次处理", p.Options.Interval, p.times)
			time.Sleep(time.Second * time.Duration(p.Options.Interval))
		}
		common.Logger.Infof("---------------- 开始第%d次处理", p.times)
		p.HandleKeyCompare(*handler)
		// Statistics are reset with the "true" flag between rounds but not
		// after the final one.
		// NOTE(review): the meaning of the Reset flag is defined in
		// metric.Stat, which is not visible here.
		if p.times < p.Options.CompareTimes {
			p.Stat.Reset(true)
		}
	}
	p.Stat.Reset(false)
	common.Logger.Infof("--------------- 处理完成 ----------------")
	return nil
}

// CreateHandler builds the comparator selected by Options.Option and returns
// a pointer to it. Every comparator shares one handle.CompareContext that
// points at p.Stat and carries Options.BatchCount.
//
// Recognised values are "len", "full" and "auto"; any other value (including
// the empty default) selects the key-existence comparator.
func (p *CompareCommand) CreateHandler() *handle.ICompareHandler {
	cmpCtx := handle.CompareContext{
		Stat:       &p.Stat,
		BatchCount: p.Options.BatchCount,
	}

	var selected handle.ICompareHandler
	switch p.Options.Option {
	case "len":
		// Compare lengths.
		selected = handle.NewValueLengthComparator(cmpCtx)
	case "full":
		// Compare values, ignoring big keys (per the original author's note).
		selected = handle.NewFullValueComparator(cmpCtx, true)
	case "auto":
		// Compare values, not ignoring big keys (per the original author's note).
		selected = handle.NewFullValueComparator(cmpCtx, false)
	default:
		selected = handle.NewKeyExistsComparator(cmpCtx)
	}
	return &selected
}

// HandleKeyCompare performs one comparison round using handler.
//
// It configures a handle.RedisScanner for the source host and a compare
// database (created from Options.DBFile / Options.ResultFile and destroyed
// on return), then hands three callbacks to StartScanRedis:
//
//   - a state callback: when state.Run is true it resets the statistics,
//     connects to the target Redis for state.CurrDb, starts saving results
//     and launches a goroutine that periodically rotates and prints the
//     statistics; when state.Run is false it closes the target connection
//     and finishes saving.
//   - a key producer: the first round scans the physical database, later
//     rounds re-read the keys recorded by the previous round.
//   - a batch handler: forwards each batch of keys to handler.HandleCompare
//     together with the conflict queue and both Redis clients.
//
// NOTE(review): when and how often StartScanRedis invokes each callback is
// defined in package handle; presumably the state callback fires once at the
// start and once at the end of every database — confirm.
func (p *CompareCommand) HandleKeyCompare(handler handle.ICompareHandler) {
	redisScanner := handle.RedisScanner{
		Host:          p.SourceHost,
		BatchCount:    p.Options.BatchCount,
		MatchPattern:  p.Options.Match,
		KeyList:       common.RegSplit(p.Options.KeyList, "|"),
		HandleThreads: p.Options.Parallel,
		Qps:           p.Options.Qps,
	}
	compareDb := handle.NewCompareDb(p.Options.DBFile, p.Options.ResultFile, p.Options.BatchCount, p.times)
	defer compareDb.Destroy()
	redisScanner.StartScanRedis(func(state *handle.ScanState) {
		if !state.Run {
			if p.targetClient != nil {
				p.targetClient.Close()
			}
			compareDb.FinishSave()
			return
		}

		p.Stat.Reset(false)
		// Keep the ticker in a local so the goroutine below always stops the
		// ticker it was started with. Reading p.tickerStat at exit could
		// stop a ticker created by a later invocation of this callback and
		// leak this one.
		ticker := time.NewTicker(time.Second * common.StatRollFrequency)
		p.tickerStat = ticker
		p.targetClient = tool.RedisConnectDB(p.TargetHost, state.CurrDb)
		compareDb.StartSave(state.CurrDb)
		go func(ctx context.Context) {
			defer ticker.Stop()
			for {
				select {
				case <-ctx.Done():
					// Exit as soon as the context ends instead of waiting
					// for the next tick.
					return
				case <-ticker.C:
					// select picks randomly when both cases are ready, so
					// re-check the context to avoid printing after it ended.
					if ctx.Err() != nil {
						return
					}
					p.Stat.Rotate()
					p.PrintStat(false, state.CurrDb, state.DbKey)
				}
			}
		}(*state.Context)
	}, func(db int32, keyQueue chan<- []*common.Key) {
		if p.times == 1 {
			redisScanner.ScanPhysicalDB(db, keyQueue)
		} else {
			compareDb.ScanLastKeys(db, keyQueue)
		}
	}, func(redisClient *client.RedisClient, i int32, keys []*common.Key) {
		handler.HandleCompare(keys, compareDb.ConflictQueue(), redisClient, p.targetClient)
	})
}

// PrintStat logs a snapshot of the comparison statistics held in p.Stat.
//
// Parameters:
//   - finished: whether the current round is complete; stored in the metric
//     as OneCompareFinished.
//   - currentDb: the database number the snapshot refers to.
//   - dbKeys: key count of that database; only reported during the first
//     round.
//
// With Options.MetricPrint set, the snapshot is logged as JSON; on the last
// round with finished == true a second JSON record is logged that
// additionally carries AllFinished, Process and the conflict totals.
// Otherwise a plain-text report is logged.
//
// Side effects: totalKeyConflict, totalFieldConflict and totalConflict are
// recomputed on every call; conflicts are only summed during the last round,
// so the totals are zero in earlier rounds.
func (p *CompareCommand) PrintStat(finished bool, currentDb int32, dbKeys int64) {
	// One clock reading, so Timestamp and DateTime describe the same instant.
	now := time.Now()
	isLastRound := p.times == p.Options.CompareTimes

	// NOTE(review): the DateTime layout ends in a literal "Z" while
	// time.Now() is in local time, so the string only denotes UTC when the
	// process runs in UTC — confirm what consumers of the metric expect.
	var buf bytes.Buffer
	metricStat := &metric.Metric{
		CompareTimes:       p.times,
		Db:                 currentDb,
		OneCompareFinished: finished,
		AllFinished:        false,
		Timestamp:          now.Unix(),
		DateTime:           now.Format("2006-01-02T15:04:05Z"),
	}
	if p.times == 1 {
		// The database key count is only reported for the first round.
		metricStat.DbKeys = dbKeys
		fmt.Fprintf(&buf, "times:%d, db:%d, dbkeys:%d, finished:%v\n", p.times, currentDb, dbKeys, finished)
	} else {
		fmt.Fprintf(&buf, "times:%d, db:%d, finished:%v\n", p.times, currentDb, finished)
	}

	p.totalConflict = 0
	p.totalKeyConflict = 0
	p.totalFieldConflict = 0

	// Keys that compared equal, per key type.
	metricStat.KeyMetric = make(map[string]map[string]*metric.CounterStat)
	for i := common.KeyTypeIndex(0); i < common.EndKeyTypeIndex; i++ {
		metricStat.KeyMetric[i.String()] = make(map[string]*metric.CounterStat)
		if p.Stat.ConflictKey[i][common.NoneConflict].Total() == 0 {
			continue
		}
		metricStat.KeyMetric[i.String()]["equal"] = p.Stat.ConflictKey[i][common.NoneConflict].Json()
		if isLastRound {
			fmt.Fprintf(&buf, "KeyEqualAtLast|%s|%s|%v\n", i, common.NoneConflict,
				p.Stat.ConflictKey[i][common.NoneConflict])
		} else {
			fmt.Fprintf(&buf, "KeyEqualInProcess|%s|%s|%v\n", i, common.NoneConflict,
				p.Stat.ConflictKey[i][common.NoneConflict])
		}
	}

	// Key conflicts, per key type and conflict type (every conflict type
	// below NoneConflict).
	for i := common.KeyTypeIndex(0); i < common.EndKeyTypeIndex; i++ {
		for j := common.ConflictType(0); j < common.NoneConflict; j++ {
			if p.Stat.ConflictKey[i][j].Total() == 0 {
				continue
			}
			metricStat.KeyMetric[i.String()][j.String()] = p.Stat.ConflictKey[i][j].Json()
			if isLastRound {
				fmt.Fprintf(&buf, "KeyConflictAtLast|%s|%s|%v\n", i, j, p.Stat.ConflictKey[i][j])
				p.totalKeyConflict += p.Stat.ConflictKey[i][j].Total()
			} else {
				fmt.Fprintf(&buf, "KeyConflictInProcess|%s|%s|%v\n", i, j, p.Stat.ConflictKey[i][j])
			}
		}
	}

	// Fields that compared equal, per key type.
	metricStat.FieldMetric = make(map[string]map[string]*metric.CounterStat)
	for i := common.KeyTypeIndex(0); i < common.EndKeyTypeIndex; i++ {
		metricStat.FieldMetric[i.String()] = make(map[string]*metric.CounterStat)
		if p.Stat.ConflictField[i][common.NoneConflict].Total() == 0 {
			continue
		}
		metricStat.FieldMetric[i.String()]["equal"] = p.Stat.ConflictField[i][common.NoneConflict].Json()
		if isLastRound {
			fmt.Fprintf(&buf, "FieldEqualAtLast|%s|%s|%v\n", i, common.NoneConflict,
				p.Stat.ConflictField[i][common.NoneConflict])
		} else {
			fmt.Fprintf(&buf, "FieldEqualInProcess|%s|%s|%v\n", i, common.NoneConflict,
				p.Stat.ConflictField[i][common.NoneConflict])
		}
	}

	// Field conflicts, per key type and conflict type.
	for i := common.KeyTypeIndex(0); i < common.EndKeyTypeIndex; i++ {
		for j := common.ConflictType(0); j < common.NoneConflict; j++ {
			if p.Stat.ConflictField[i][j].Total() == 0 {
				continue
			}
			metricStat.FieldMetric[i.String()][j.String()] = p.Stat.ConflictField[i][j].Json()
			if isLastRound {
				fmt.Fprintf(&buf, "FieldConflictAtLast|%s|%s|%v\n", i, j, p.Stat.ConflictField[i][j])
				p.totalFieldConflict += p.Stat.ConflictField[i][j].Total()
			} else {
				fmt.Fprintf(&buf, "FieldConflictInProcess|%s|%s|%v\n", i, j, p.Stat.ConflictField[i][j])
			}
		}
	}

	p.totalConflict = p.totalKeyConflict + p.totalFieldConflict

	if !p.Options.MetricPrint {
		common.Logger.Infof("statItems:\n%s", buf.String())
		return
	}

	// logMetric writes the current content of metricStat as one JSON log
	// line, reporting an encoding failure instead of logging an empty line.
	logMetric := func() {
		encoded, err := json.Marshal(metricStat)
		if err != nil {
			common.Logger.Infof("marshal metric failed: %v", err)
			return
		}
		common.Logger.Info(string(encoded))
	}

	logMetric()
	if isLastRound && finished {
		// The final snapshot is logged a second time, now flagged as
		// complete and carrying the conflict totals.
		metricStat.AllFinished = true
		metricStat.Process = int64(100)
		metricStat.TotalConflict = p.totalConflict
		metricStat.TotalKeyConflict = p.totalKeyConflict
		metricStat.TotalFieldConflict = p.totalFieldConflict
		logMetric()
	}
}
